{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2fb59980-27fb-4a18-8261-8ed024cc0501",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import patches\n",
    "from pathlib import Path\n",
    "import random\n",
    "import fiftyone as fo\n",
    "import re"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "84ef2c20-ab6c-49f0-97d5-4508a500cf14",
   "metadata": {},
   "outputs": [],
   "source": [
    "problem_name = 'full_spe'\n",
    "method_name = 'pol_res_mot'\n",
    "dataset_name = 'val_{}_{}'.format(problem_name, method_name)\n",
    "default_label = 'thing'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "59b29caa-ccf0-46a6-802f-45bdc29e3001",
   "metadata": {},
   "outputs": [],
   "source": [
    "fo_dataset = fo.load_dataset(dataset_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "af00fba3-ad1b-4917-b336-d0a845c5f6e5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating detections...\n",
      " 100% |███████████████| 3549/3549 [4.5m elapsed, 0s remaining, 13.9 samples/s]      \n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "       thing       0.01      0.55      0.02     10191\n",
      "\n",
      "   micro avg       0.01      0.55      0.02     10191\n",
      "   macro avg       0.01      0.55      0.02     10191\n",
      "weighted avg       0.01      0.55      0.02     10191\n",
      "\n",
      "Evaluating detections...\n",
      " 100% |███████████████| 3549/3549 [24.8s elapsed, 0s remaining, 121.0 samples/s]      \n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "       thing       0.28      0.42      0.33     10191\n",
      "\n",
      "   micro avg       0.28      0.42      0.33     10191\n",
      "   macro avg       0.28      0.42      0.33     10191\n",
      "weighted avg       0.28      0.42      0.33     10191\n",
      "\n",
      "Evaluating detections...\n",
      " 100% |███████████████| 3549/3549 [21.1s elapsed, 0s remaining, 136.3 samples/s]      \n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "       thing       0.54      0.41      0.47     10191\n",
      "\n",
      "   micro avg       0.54      0.41      0.47     10191\n",
      "   macro avg       0.54      0.41      0.47     10191\n",
      "weighted avg       0.54      0.41      0.47     10191\n",
      "\n",
      "Evaluating detections...\n",
      " 100% |███████████████| 3549/3549 [19.9s elapsed, 0s remaining, 175.9 samples/s]      \n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "       thing       0.72      0.41      0.52     10191\n",
      "\n",
      "   micro avg       0.72      0.41      0.52     10191\n",
      "   macro avg       0.72      0.41      0.52     10191\n",
      "weighted avg       0.72      0.41      0.52     10191\n",
      "\n",
      "Evaluating detections...\n",
      " 100% |███████████████| 3549/3549 [19.8s elapsed, 0s remaining, 180.3 samples/s]      \n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "       thing       0.79      0.41      0.54     10191\n",
      "\n",
      "   micro avg       0.79      0.41      0.54     10191\n",
      "   macro avg       0.79      0.41      0.54     10191\n",
      "weighted avg       0.79      0.41      0.54     10191\n",
      "\n"
     ]
    }
   ],
   "source": [
    "field_names = fo_dataset.get_field_schema().keys()\n",
    "itrs = []\n",
    "for name in field_names:\n",
    "    m = re.match('pred_([0-9]+)', name)\n",
    "    if m is not None:\n",
    "        itrs.append(int(m.group(1)))\n",
    "for itr in itrs:\n",
    "    if itr <= 1: continue\n",
    "    results = fo_dataset.evaluate_detections(\n",
    "        f'pred_{itr}',\n",
    "        gt_field='ground_truth',\n",
    "        eval_key=f'eval_{itr}',\n",
    "        compute_mAP=False,\n",
    "    )\n",
    "    metrics = results.metrics()\n",
    "    results.print_report(classes=[default_label])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e31008d0-3bd1-41f4-aeb9-18a5f0560987",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating detections...\n",
      " 100% |███████████████| 3549/3549 [19.5s elapsed, 0s remaining, 187.9 samples/s]      \n"
     ]
    }
   ],
   "source": [
    "results = fo_dataset.evaluate_detections(\n",
    "    f'pred_{itr}',\n",
    "    gt_field='ground_truth',\n",
    "    eval_key=f'eval_{itr}',\n",
    "    compute_mAP=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0ebb4ee6-367d-476a-9378-fdc2b15594bc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "       thing       0.79      0.41      0.54     10191\n",
      "\n",
      "   micro avg       0.79      0.41      0.54     10191\n",
      "   macro avg       0.79      0.41      0.54     10191\n",
      "weighted avg       0.79      0.41      0.54     10191\n",
      "\n",
      "None\n"
     ]
    }
   ],
   "source": [
    "print(results.print_report(classes=[default_label]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "b0d7c275-0650-42ce-9326-dfccad563507",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'accuracy': 0.37014423502344923,\n",
       " 'precision': 0.790289061023994,\n",
       " 'recall': 0.4104602099892062,\n",
       " 'fscore': 0.5402996641694653,\n",
       " 'support': 10191}"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results.metrics()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "c8f783d9-7c7c-4bd5-9627-bb075c53f76d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "odict_keys(['id', 'filepath', 'tags', 'metadata', 'ground_truth', 'pred_0', 'pred_1', 'pred_5', 'pred_10', 'pred_20', 'pred_100', 'eval_tp', 'eval_fp', 'eval_fn', 'eval_0_tp', 'eval_0_fp', 'eval_0_fn', 'eval_1_tp', 'eval_1_fp', 'eval_1_fn', 'eval_5_tp', 'eval_5_fp', 'eval_5_fn', 'eval_10_tp', 'eval_10_fp', 'eval_10_fn', 'eval_20_tp', 'eval_20_fp', 'eval_20_fn', 'eval_100_tp', 'eval_100_fp', 'eval_100_fn'])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fo_dataset.get_field_schema().keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3efc253f-7f59-46b5-b5d7-1129eeaf8f48",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
